import pandas as pd
import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def feat(x, y, z, n_points=100):
    """Fit ``z = a*x**2 + b*y**2 + c`` to scattered points and sample the fit on a grid.

    Parameters
    ----------
    x, y, z : array-like
        Sample coordinates and values (1-D, equal length, as used by this script).
    n_points : int, optional
        Number of samples along each axis of the returned grid. Defaults to 100,
        the resolution that was previously hard-coded.

    Returns
    -------
    x_fit, y_fit, z_fit : ndarray
        ``(n_points, n_points)`` meshgrid arrays spanning ``[min(x), max(x)]`` by
        ``[min(y), max(y)]``, and the fitted surface evaluated on that grid.
    """
    def surface_model(xy, a, b, c):
        # Axis-aligned paraboloid: independent quadratic terms in each coordinate.
        u, v = xy
        return a * u ** 2 + b * v ** 2 + c

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    z = np.asarray(z, dtype=float)

    # Only the best-fit parameters are needed; the covariance estimate is discarded.
    params, _ = curve_fit(surface_model, (x, y), z)

    # Evaluate the fitted model on a regular grid covering the sample extent.
    x_fit, y_fit = np.meshgrid(
        np.linspace(x.min(), x.max(), n_points),
        np.linspace(y.min(), y.max(), n_points),
    )
    z_fit = surface_model((x_fit, y_fit), *params)
    return x_fit, y_fit, z_fit


def parse_recalls(lines):
    """Extract one recall value per non-blank line of a results log.

    Each line is split on whitespace. Its 4th token is parsed as a float after
    dropping the token's final character. That character is presumably a
    trailing separator such as ',' or '%' (TODO confirm against the log format).

    Parameters
    ----------
    lines : iterable of str
        Lines of the log, e.g. an open text file.

    Returns
    -------
    list of float
        Recall values in file order.
    """
    recalls = []
    for line in lines:
        parts = line.split()
        if not parts:
            # Tolerate blank lines (e.g. a trailing newline at EOF) instead of
            # failing with IndexError on parts[3].
            continue
        recalls.append(float(parts[3][:-1]))
    return recalls


def to_grid(recalls, n_rows, n_cols):
    """Arrange a flat sequence of recalls into an ``(n_rows, n_cols)`` float array.

    Values are assigned row-major: the first ``n_cols`` values form row 0, the
    next ``n_cols`` form row 1, and so on. Extra trailing values are ignored.

    Raises
    ------
    ValueError
        If fewer than ``n_rows * n_cols`` values are available.
    """
    needed = n_rows * n_cols
    if len(recalls) < needed:
        raise ValueError(
            f"expected at least {needed} recall values, got {len(recalls)}")
    return np.asarray(recalls[:needed], dtype=float).reshape(n_rows, n_cols)


def extract_region(grid, g_values, d_values, g_idx, d_idx):
    """Flatten a rectangular sub-block of ``grid`` into scatter-point arrays.

    ``grid[i, j]`` is the value measured at ``g_values[i]`` and ``d_values[j]``.
    Points are ordered with the row (G) index outermost and the column (D)
    index innermost.

    Returns
    -------
    x, y, z : ndarray
        1-D arrays of G values, D values and the corresponding grid values.
    """
    gi, di = np.meshgrid(np.asarray(g_idx), np.asarray(d_idx), indexing='ij')
    x = np.asarray(g_values, dtype=float)[gi].ravel()
    y = np.asarray(d_values, dtype=float)[di].ravel()
    z = np.asarray(grid)[gi, di].ravel()
    return x, y, z


def plot_fit(x, y, z, x_fit, y_fit, z_fit):
    """Draw sample points and their fitted surface in a new 3-D figure.

    Returns the created ``(figure, axes)`` pair.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(x, y, z, c='r', marker='o', label='Sample Points')
    # Tessellate the surface into ~25x25 faces. A stride equal to the grid size
    # (e.g. rstride=cstride=100 on a 100x100 grid) collapses the whole surface
    # into a single flat polygon and hides the fitted curvature.
    ax.plot_surface(x_fit, y_fit, z_fit, alpha=0.5, rcount=25, ccount=25,
                    color='b', label='Fitted Surface')
    ax.set_xlabel('G_lr')
    ax.set_ylabel('D_lr')
    ax.set_zlabel('Recall')
    return fig, ax


if __name__ == '__main__':
    # Learning-rate values searched for G and D (presumably generator and
    # discriminator -- confirm with the training code).
    G_LR = [0.00001, 0.00002, 0.00003, 0.00004, 0.00005, 0.00006, 0.00007, 0.00008, 0.00009, 0.0001, 0.0002, 0.0003,
            0.0004, 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.001, 0.002, 0.003, 0.004, 0.005]
    D_LR = [0.00001, 0.00002, 0.00003, 0.00004, 0.00005, 0.00006, 0.00007, 0.00008, 0.00009, 0.0001, 0.0002, 0.0003,
            0.0004, 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.001, 0.002, 0.003, 0.004, 0.005]
    np.set_printoptions(suppress=True)

    # One recall per (G_LR, D_LR) run, with G_LR varying slowest in the file.
    with open('./first_test.txt', 'r') as file:
        Recall_all = parse_recalls(file)
    Recall = to_grid(Recall_all, len(G_LR), len(D_LR))
    print(Recall.shape)

    # Index ranges (into G_LR, D_LR) of the sub-grids fitted and plotted separately.
    regions = [
        (range(0, 10), range(9, 19)),    # G 1e-5..1e-4, D 1e-4..1e-3
        (range(9, 19), range(0, 10)),    # G 1e-4..1e-3, D 1e-5..1e-4
        (range(18, 23), range(18, 23)),  # G 1e-3..5e-3, D 1e-3..5e-3
    ]
    for g_idx, d_idx in regions:
        xs, ys, zs = extract_region(Recall, G_LR, D_LR, g_idx, d_idx)
        print(xs.shape, ys.shape, zs.shape)
        # Fit a quadratic surface to the region, then plot points and fit.
        x_fit, y_fit, z_fit = feat(xs, ys, zs)
        plot_fit(xs, ys, zs, x_fit, y_fit, z_fit)
    plt.show()

